{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VBOfRw7ifk8w"
      },
      "outputs": [],
      "source": [
        "# Copyright 2021 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w3LR4Lj8fk8x"
      },
      "source": [
        "# Feedback or issues?\n",
        "For any feedback or questions, please open an [issue](https://github.com/googleapis/python-aiplatform/issues)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mUMzY5W9fk8y"
      },
      "source": [
        "# Vertex SDK for Python: AutoML Image Classification Training with Customer Managed Encryption Keys (CMEK) Example\n",
        "To use this Jupyter notebook, create a copy of the notebook in Colab and open it. You can run each step, or cell, and see its results. To run a cell, use Shift+Enter. Colab automatically displays the return value of the last line in each cell.\n",
        "\n",
        "This notebook demonstrates how to train an AutoML Image Classification model with CMEK. It requires you to provide a bucket where the dataset will be stored.\n",
        "\n",
        "Note: you may incur charges for training, prediction, storage or usage of other GCP products in connection with testing this SDK."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lld3eeJUs5yM"
      },
      "source": [
        "# Install SDK\n",
        " \n",
        "After the SDK is installed, the kernel restarts automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sBfZtR4X1Dr_"
      },
      "outputs": [],
      "source": [
        "!pip3 uninstall -y google-cloud-aiplatform\n",
        "!pip3 install --upgrade google-cloud-kms\n",
        "!pip3 install google-cloud-aiplatform\n",
        "import IPython\n",
        "\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c0SNmTBeD2nV"
      },
      "source": [
        "### Enter your project and GCS bucket\n",
        "\n",
        "Authenticate first if you are running in Colab, then enter your region, project ID, and GCS staging bucket in the settings cell below. Run the cells so that the rest of the notebook uses the right project for all its commands."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YcwsEwXPivBZ"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iqSQT6Z6bekX"
      },
      "outputs": [],
      "source": [
        "REGION = \"YOUR REGION\"  # e.g. us-central1\n",
        "MY_PROJECT = \"YOUR PROJECT ID\"\n",
        "MY_STAGING_BUCKET = \"gs://YOUR BUCKET\"  # bucket should be in the same region as Vertex AI"
      ]
    },
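    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The settings above ship with placeholder values that must be edited before later cells can succeed. The cell below is a minimal sketch of a sanity check for them. The helper `check_config` is hypothetical and written only for this illustration; it flags values that still contain spaces (as the placeholders do) and bucket values that are not `gs://` URIs. It does not contact Google Cloud, so it cannot confirm that the project or bucket exists."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper for illustration; not part of the Vertex SDK.\n",
        "def check_config(region, project, staging_bucket):\n",
        "    \"\"\"Return a list of problems found in the settings (empty if none).\"\"\"\n",
        "    problems = []\n",
        "    for name, value in ((\"REGION\", region), (\"MY_PROJECT\", project)):\n",
        "        # The shipped placeholders contain spaces; real IDs and regions do not.\n",
        "        if not value or \" \" in value:\n",
        "            problems.append(f\"{name} looks like an unedited placeholder: {value!r}\")\n",
        "    if not staging_bucket.startswith(\"gs://\") or \" \" in staging_bucket:\n",
        "        problems.append(\n",
        "            f\"MY_STAGING_BUCKET should be a gs:// URI without spaces: {staging_bucket!r}\"\n",
        "        )\n",
        "    return problems\n",
        "\n",
        "\n",
        "# Example values only; replace them with REGION, MY_PROJECT, MY_STAGING_BUCKET.\n",
        "print(check_config(\"YOUR REGION\", \"YOUR PROJECT ID\", \"gs://YOUR BUCKET\"))\n",
        "print(check_config(\"us-central1\", \"my-project\", \"gs://my-bucket\"))"
      ]
    },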
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mRk9eoTm6Pyi"
      },
      "source": [
        "## Setting up Customer Managed Encryption Keys\n",
        "\n",
        "By default, Google Cloud automatically encrypts data when it is at rest using encryption keys managed by Google. If you have specific compliance or regulatory requirements related to the keys that protect your data, you can use customer-managed encryption keys (CMEK) for your training jobs.\n",
        "\n",
        "For more info on using CMEK on Vertex AI, please see: [https://cloud.google.com/vertex-ai/docs/general/cmek#before_you_begin](https://cloud.google.com/vertex-ai/docs/general/cmek#before_you_begin)\n",
        "\n",
        "You can create a key by following the guide above or by executing the notebook cells below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RD_Pvrg584X3"
      },
      "source": [
        "1. Enable the Cloud Key Management Service (KMS) API for your project in Google Cloud Platform at https://console.cloud.google.com/flows/enableapi?apiid=cloudkms.googleapis.com\n",
        "\n",
        "2. Create a key ring"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dxRZzbvQnZC7"
      },
      "outputs": [],
      "source": [
        "KEY_RING_ID = \"your_key_ring_name\"\n",
        "\n",
        "\n",
        "# Reference: https://cloud.google.com/kms/docs/samples/kms-create-key-ring\n",
        "def create_key_ring(project_id, location_id, id):\n",
        "    \"\"\"\n",
        "    Creates a new key ring in Cloud KMS\n",
        "\n",
        "    Args:\n",
        "        project_id (string): Google Cloud project ID (e.g. 'my-project').\n",
        "        location_id (string): Cloud KMS location (e.g. 'us-east1').\n",
        "        id (string): ID of the key ring to create (e.g. 'my-key-ring').\n",
        "\n",
        "    Returns:\n",
        "        KeyRing: Cloud KMS key ring.\n",
        "\n",
        "    \"\"\"\n",
        "\n",
        "    # Import the client library.\n",
        "    from google.cloud import kms\n",
        "\n",
        "    # Create the client.\n",
        "    client = kms.KeyManagementServiceClient()\n",
        "\n",
        "    # Build the parent location name.\n",
        "    location_name = f\"projects/{project_id}/locations/{location_id}\"\n",
        "\n",
        "    # Build the key ring.\n",
        "    key_ring = {}\n",
        "\n",
        "    # Call the API.\n",
        "    created_key_ring = client.create_key_ring(\n",
        "        request={\"parent\": location_name, \"key_ring_id\": id, \"key_ring\": key_ring}\n",
        "    )\n",
        "    print(\"Created key ring: {}\".format(created_key_ring.name))\n",
        "    return created_key_ring\n",
        "\n",
        "\n",
        "create_key_ring(project_id=MY_PROJECT, location_id=REGION, id=KEY_RING_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gCL1-IfFtWXl"
      },
      "source": [
        "Create a key"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LXcagdmSnYYW"
      },
      "outputs": [],
      "source": [
        "KEY_ID = \"your_key_id\"\n",
        "\n",
        "\n",
        "# Reference: https://cloud.google.com/kms/docs/samples/kms-create-key-symmetric-encrypt-decrypt\n",
        "def create_key_symmetric_encrypt_decrypt(project_id, location_id, key_ring_id, id):\n",
        "    \"\"\"\n",
        "    Creates a new symmetric encryption/decryption key in Cloud KMS.\n",
        "\n",
        "    Args:\n",
        "        project_id (string): Google Cloud project ID (e.g. 'my-project').\n",
        "        location_id (string): Cloud KMS location (e.g. 'us-east1').\n",
        "        key_ring_id (string): ID of the Cloud KMS key ring (e.g. 'my-key-ring').\n",
        "        id (string): ID of the key to create (e.g. 'my-symmetric-key').\n",
        "\n",
        "    Returns:\n",
        "        CryptoKey: Cloud KMS key.\n",
        "\n",
        "    \"\"\"\n",
        "\n",
        "    # Import the client library.\n",
        "    from google.cloud import kms\n",
        "\n",
        "    # Create the client.\n",
        "    client = kms.KeyManagementServiceClient()\n",
        "\n",
        "    # Build the parent key ring name.\n",
        "    key_ring_name = client.key_ring_path(project_id, location_id, key_ring_id)\n",
        "\n",
        "    # Build the key.\n",
        "    purpose = kms.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT\n",
        "    algorithm = (\n",
        "        kms.CryptoKeyVersion.CryptoKeyVersionAlgorithm.GOOGLE_SYMMETRIC_ENCRYPTION\n",
        "    )\n",
        "    key = {\n",
        "        \"purpose\": purpose,\n",
        "        \"version_template\": {\n",
        "            \"algorithm\": algorithm,\n",
        "        },\n",
        "    }\n",
        "\n",
        "    # Call the API.\n",
        "    created_key = client.create_crypto_key(\n",
        "        request={\"parent\": key_ring_name, \"crypto_key_id\": id, \"crypto_key\": key}\n",
        "    )\n",
        "    print(\"Created symmetric key: {}\".format(created_key.name))\n",
        "    return created_key\n",
        "\n",
        "\n",
        "create_key_symmetric_encrypt_decrypt(\n",
        "    project_id=MY_PROJECT, location_id=REGION, key_ring_id=KEY_RING_ID, id=KEY_ID\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3gKDBOqC8Gl5"
      },
      "source": [
        "Grant the Vertex AI service account permission to use the key"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6QrRg08Vqfru"
      },
      "outputs": [],
      "source": [
        "# Reference: https://cloud.google.com/vertex-ai/docs/general/cmek#granting_permissions\n",
        "# Get the service account\n",
        "SERVICE_ACCOUNT = ! gcloud projects get-iam-policy {MY_PROJECT} \\\n",
        "  --flatten=\"bindings[].members\" \\\n",
        "  --format=\"table(bindings.members)\" \\\n",
        "  --filter=\"bindings.role:roles/aiplatform.serviceAgent\" \\\n",
        "  | grep -oP \"service-.+?@gcp-sa-aiplatform.iam.gserviceaccount.com\"\n",
        "SERVICE_ACCOUNT = SERVICE_ACCOUNT[0]\n",
        "\n",
        "print(f\"Service account is: {SERVICE_ACCOUNT}\")\n",
        "\n",
        "# Give permissions\n",
        "!gcloud kms keys add-iam-policy-binding {KEY_ID} \\\n",
        "  --keyring={KEY_RING_ID} \\\n",
        "  --location={REGION} \\\n",
        "  --project={MY_PROJECT} \\\n",
        "  --member=serviceAccount:{SERVICE_ACCOUNT} \\\n",
        "  --role=roles/cloudkms.cryptoKeyEncrypterDecrypter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ebAHZg2vlhXL"
      },
      "outputs": [],
      "source": [
        "# Create the full resource identifier for the created key\n",
        "ENCRYPTION_SPEC_KEY_NAME = f\"projects/{MY_PROJECT}/locations/{REGION}/keyRings/{KEY_RING_ID}/cryptoKeys/{KEY_ID}\""
      ]
    },
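    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A Cloud KMS key is identified by a resource name of the form `projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY`. The cell below is a minimal sketch that checks whether a string has that shape before it is used as an encryption key name. The helper `is_valid_kms_key_name` and its regular expression are assumptions made for illustration and are not part of the Vertex SDK or the Cloud KMS client. The check only looks at the shape of the string and does not confirm that the key exists."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Hypothetical helper for illustration; it checks only the shape of the name.\n",
        "_KMS_KEY_PATTERN = re.compile(\n",
        "    r\"^projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+$\"\n",
        ")\n",
        "\n",
        "\n",
        "def is_valid_kms_key_name(name):\n",
        "    \"\"\"Return True if `name` looks like a full CryptoKey resource name.\"\"\"\n",
        "    return bool(_KMS_KEY_PATTERN.match(name))\n",
        "\n",
        "\n",
        "# Example value only; in this notebook you would pass ENCRYPTION_SPEC_KEY_NAME.\n",
        "example = \"projects/my-project/locations/us-central1/keyRings/my-ring/cryptoKeys/my-key\"\n",
        "print(is_valid_kms_key_name(example))\n",
        "print(is_valid_kms_key_name(\"my-key\"))"
      ]
    },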
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aa_8wrqSkamz"
      },
      "source": [
        "## Initialize Vertex SDK for Python\n",
        "\n",
        "Initialize the *client* for Vertex AI.\n",
        "\n",
        "All resources created during this notebook run will be encrypted with the encryption key created above.\n",
        "\n",
        "You can override the encryption key at each function call."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ohdgOs69kGNU"
      },
      "outputs": [],
      "source": [
        "from google.cloud import aiplatform\n",
        "\n",
        "aiplatform.init(\n",
        "    project=MY_PROJECT,\n",
        "    staging_bucket=MY_STAGING_BUCKET,\n",
        "    location=REGION,\n",
        "    encryption_spec_key_name=ENCRYPTION_SPEC_KEY_NAME,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "35QVNhACqcTJ"
      },
      "source": [
        "# Create Managed Image Dataset from CSV\n",
        "\n",
        "This section creates a managed image dataset from the Flowers dataset. For more information on this dataset, please visit https://www.tensorflow.org/datasets/catalog/tf_flowers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4OfCqaYRqcTJ"
      },
      "outputs": [],
      "source": [
        "IMPORT_FILE = (\n",
        "    \"gs://cloud-samples-data/vision/automl_classification/flowers/all_data_v2.csv\"\n",
        ")\n",
        "\n",
        "ds = aiplatform.ImageDataset.create(\n",
        "    display_name=\"flowers\",\n",
        "    gcs_source=[IMPORT_FILE],\n",
        "    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,\n",
        ")\n",
        "\n",
        "ds.resource_name"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6-bBqipfqcTS"
      },
      "source": [
        "# Launch a Training Job to Create a Model\n",
        "\n",
        "Train an AutoML Image Classification model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aA41rT_mb-rV"
      },
      "outputs": [],
      "source": [
        "job = aiplatform.AutoMLImageTrainingJob(\n",
        "    display_name=\"train-flowers-automl-mbsdk-1\",\n",
        "    prediction_type=\"classification\",\n",
        "    multi_label=False,\n",
        "    model_type=\"CLOUD\",\n",
        "    base_model=None,\n",
        ")\n",
        "\n",
        "# This will take around half an hour to run\n",
        "model = job.run(\n",
        "    dataset=ds,\n",
        "    model_display_name=\"flowers-classification-model-mbsdk\",\n",
        "    training_fraction_split=0.6,\n",
        "    validation_fraction_split=0.2,\n",
        "    test_fraction_split=0.2,\n",
        "    budget_milli_node_hours=8000,\n",
        "    disable_early_stopping=False,\n",
        ")"
      ]
    },
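    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training call above splits the dataset 60/20/20 between training, validation, and test, and sets a budget in milli node hours, where 1,000 milli node hours correspond to 1 node hour. The cell below is a small arithmetic sketch that repeats those numbers to show that the fractions used here add up to 1 and that the budget is 8 node hours. The variable names mirror the arguments of `job.run`; the cell itself does not call the SDK."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Values copied from the job.run(...) call in this notebook.\n",
        "training_fraction_split = 0.6\n",
        "validation_fraction_split = 0.2\n",
        "test_fraction_split = 0.2\n",
        "budget_milli_node_hours = 8000\n",
        "\n",
        "total = training_fraction_split + validation_fraction_split + test_fraction_split\n",
        "# Compare with a tolerance because floating-point sums are not always exact.\n",
        "splits_sum_to_one = math.isclose(total, 1.0)\n",
        "\n",
        "# 1,000 milli node hours correspond to 1 node hour.\n",
        "budget_node_hours = budget_milli_node_hours / 1000\n",
        "\n",
        "print(splits_sum_to_one, budget_node_hours)"
      ]
    },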
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5vhDsMJNqcTW"
      },
      "source": [
        "# Deploy Your Model\n",
        "\n",
        "Deploy your model, then wait until the model FINISHES deployment before proceeding to prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y9GH72wWqcTX"
      },
      "outputs": [],
      "source": [
        "endpoint = model.deploy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nIw1ifPuqcTb"
      },
      "source": [
        "# Predict on Endpoint\n",
        "- Take one sample from the data imported to the dataset\n",
        "- This sample will be encoded to base64 and passed to the endpoint for prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H23ISHdHVIZM"
      },
      "outputs": [],
      "source": [
        "test_item = !gsutil cat $IMPORT_FILE | head -n1\n",
        "test_item, test_label = str(test_item[0]).split(\",\")\n",
        "\n",
        "print(test_item, test_label)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TF_N0kqZU768"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "with tf.io.gfile.GFile(test_item, \"rb\") as f:\n",
        "    content = f.read()\n",
        "\n",
        "# The format of each instance should conform to the deployed model's prediction input schema.\n",
        "instances_list = [{\"content\": base64.b64encode(content).decode(\"utf-8\")}]\n",
        "\n",
        "prediction = endpoint.predict(instances=instances_list)\n",
        "\n",
        "prediction"
      ]
    },
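    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The prediction returned by the endpoint carries one entry per instance. The cell below is a minimal sketch of how the most likely label could be picked out of one such entry. It assumes that each entry is a mapping with parallel `displayNames` and `confidences` lists, which is how AutoML image classification results are commonly shaped; verify this against the output of the previous cell. The helper `top_prediction` is hypothetical and the sample values are made up for illustration, not real model output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper for illustration; not part of the Vertex SDK.\n",
        "def top_prediction(prediction_dict):\n",
        "    \"\"\"Return (display_name, confidence) for the highest-confidence label.\"\"\"\n",
        "    names = prediction_dict[\"displayNames\"]\n",
        "    confidences = prediction_dict[\"confidences\"]\n",
        "    # Index of the largest confidence value.\n",
        "    best = max(range(len(confidences)), key=lambda i: confidences[i])\n",
        "    return names[best], confidences[best]\n",
        "\n",
        "\n",
        "# Made-up values for illustration only; not real model output.\n",
        "sample = {\n",
        "    \"ids\": [\"1\", \"2\"],\n",
        "    \"displayNames\": [\"daisy\", \"tulips\"],\n",
        "    \"confidences\": [0.9, 0.1],\n",
        "}\n",
        "print(top_prediction(sample))\n",
        "\n",
        "# In this notebook the same helper could be tried on prediction.predictions[0],\n",
        "# provided the entry has the assumed keys."
      ]
    },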
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nWA3qocXfk82"
      },
      "source": [
        "# Undeploy Model from Endpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V1brMaO_fk82"
      },
      "outputs": [],
      "source": [
        "endpoint.undeploy_all()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "AI_Platform_(Unified)_SDK_AutoML_Image_Classification_Training_with_CMEK.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
